/*
 * UTF - 8
 * 尝试在该WiFi系统中添加对各个信道状态的追踪器。
 * 一个AP与多个STA的WiFi系统，其中一个STA为MLD，其余均仅使用一个channel，架构图如下所示：
 *      --- 5180MHz --- STA0，STA1
 *      --- 2412MHz --- STA0，STA2
 * AP   --- 5945MHz --- STA0
 * 在STA0上开启了emlsr功能
 * 在所有STA上开启了wifi探针机制
 * 最后统计传输速率的部分仅统计了服务端一个UDP端口的接收速率，如果想要统计所有端口的接收速率，
 * 需要修改下述代码为：
 *
 * - serverApp = server.Install(serverNodes.Get(0));
 *
 * + serverApp.Add(server.Install(serverNodes.Get(0)));
 *
 * 最后输出的pcap文件名的第一个数字是节点的序号，该序号为节点被创建(create)的顺序
 * 至于每个pcap文件对应于哪一个信道，需要打开pcap文件后自行查阅记录的phy头部的radiotap.channel.freq字段
 * 
 * 对各个节点的info输出的数字为节点在自己所属的Container的序号，该序号为节点add进Container的顺序
 */
#include "ns3/abort.h"
#include "ns3/boolean.h"
#include "ns3/callback.h"
#include "ns3/command-line.h"
#include "ns3/config.h"
#include "ns3/double.h"
#include "ns3/eht-phy.h"
#include "ns3/enum.h"
#include "ns3/frame-exchange-manager.h"
#include "ns3/internet-stack-helper.h"
#include "ns3/ipv4-address-helper.h"
#include "ns3/log.h"
#include "ns3/mobility-helper.h"
#include "ns3/multi-model-spectrum-channel.h"
#include "ns3/on-off-helper.h"
#include "ns3/packet-sink-helper.h"
#include "ns3/packet-sink.h"
#include "ns3/propagation-loss-model.h"
#include "ns3/rng-seed-manager.h"
#include "ns3/spectrum-wifi-helper.h"
#include "ns3/ssid.h"
#include "ns3/string.h"
#include "ns3/udp-client-server-helper.h"
#include "ns3/uinteger.h"
#include "ns3/wifi-acknowledgment.h"
#include "ns3/wifi-mac.h"
#include "ns3/wifi-net-device.h"
#include "ns3/wifi-phy-state-helper.h"
#include "ns3/wifi-phy.h"

#include <algorithm>
#include <array>
#include <chrono>
#include <cstdint>
#include <ctime>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

using namespace ns3;

NS_LOG_COMPONENT_DEFINE("EhtStaApLinkLink");

/**
 * @brief Appends a timestamped entry to the log file "multiStaTraceChannel.cc.log".
 *
 * Each entry is written as "[HH-MM-SS.mmm] <logStr>" using local wall-clock time.
 * No newline is appended, so callers must terminate their message themselves.
 *
 * The output stream is opened lazily (append mode) and then kept open, instead of being
 * reopened and closed for every entry: this function is called for every PHY state
 * change, so per-call open/close dominated its cost. The stream is flushed after each
 * entry so the file content stays complete even if the simulation aborts. If the file
 * cannot be opened (or a previous write failed), opening is retried on the next call and
 * an error message is printed to stdout.
 *
 * @param logStr The log string to be written.
 */
void
writeLog(const std::string& logStr)
{
    const auto currentTime = std::chrono::system_clock::now();
    const std::time_t currentTime_t = std::chrono::system_clock::to_time_t(currentTime);
    const auto milliseconds =
        std::chrono::duration_cast<std::chrono::milliseconds>(currentTime.time_since_epoch())
            .count() %
        1000;

    char buffer[80] = "??-??-??";
    // std::localtime may return nullptr; passing that to strftime would be undefined
    // behaviour, so the placeholder is kept in that case.
    if (const std::tm* currentTime_tm = std::localtime(&currentTime_t))
    {
        std::strftime(buffer, sizeof(buffer), "%H-%M-%S", currentTime_tm);
    }
    std::ostringstream ss;
    ss << buffer << "." << std::setfill('0') << std::setw(3) << milliseconds;
    const std::string formattedTime = ss.str();

    static std::ofstream outputFile;
    if (!outputFile.is_open() || !outputFile.good())
    {
        // (Re)open on first use, after a failed open, or after a write error.
        outputFile.close();
        outputFile.clear();
        outputFile.open("multiStaTraceChannel.cc.log", std::ios::app);
    }

    if (outputFile.is_open())
    {
        outputFile << "[" << formattedTime << "] " << logStr << std::flush;
    }
    else
    {
        std::cout << "[" << formattedTime << "]"
                  << "Failed to open or create the file!" << std::endl;
    }
}

void
WifiStateTrace(std::string context, Time start, Time duration, WifiPhyState state)
{
    std::stringstream nowStr;
    Time now = Simulator::Now();
    nowStr << "[" << now.GetNanoSeconds() << "] ";
    std::stringstream startStr;
    startStr << "Start [" << start.GetNanoSeconds() << "] ";
    std::stringstream durationStr;
    durationStr << "Duration [" << duration.GetNanoSeconds() << "] ";
    std::string stateStr;
    switch (state)
    {
    case IDLE:
        stateStr = "IDLE";
        break;
    case CCA_BUSY:
        stateStr = "CCA_BUSY";
        break;
    case TX:
        stateStr = "TX";
        break;
    case RX:
        stateStr = "RX";
        break;
    case SWITCHING:
        stateStr = "SWITCHING";
        break;
    case SLEEP:
        stateStr = "SLEEP";
        break;
    case OFF:
        stateStr = "OFF";
        break;
    default:
        NS_FATAL_ERROR("Invalid state");
        stateStr = "INVALID";
        break;
    }
    writeLog("[" + context + "] " + nowStr.str() + startStr.str() + durationStr.str() + stateStr +
             "\n");
}

/**
 * Connects WifiStateTrace to the "State" trace source of every PHY of a WifiNetDevice.
 *
 * PHYs are visited in the order returned by GetPhys(); the i-th PHY is connected with
 * the context "<context><i>" (no separator, e.g. "wifi-ap0"). A connection that fails is
 * reported on std::cerr rather than being silently ignored.
 *
 * @param wifi_dev The WifiNetDevice whose PHY states are traced.
 * @param context The prefix of the per-PHY context string.
 */
void
ConnectContextTrace2Mac(Ptr<WifiNetDevice> wifi_dev, std::string context)
{
    std::size_t phyid = 0;
    for (const auto& wifi_phy : wifi_dev->GetPhys())
    {
        const std::string phyContext = context + std::to_string(phyid);
        const bool connected =
            wifi_phy->GetState()->TraceConnect("State", phyContext, MakeCallback(&WifiStateTrace));
        if (!connected)
        {
            std::cerr << "Failed to connect the State trace for " << phyContext << std::endl;
        }
        ++phyid;
    }
}

/**
 * \param udp true if UDP is used, false if TCP is used
 * \param serverApp a container of server applications
 * \param payloadSize the size in bytes of the packets
 * \return the bytes received by each server application
 */
std::vector<uint64_t>
StaApLinkLinkGetRxBytes(bool udp, const ApplicationContainer& serverApp, uint32_t payloadSize)
{
    std::vector<uint64_t> rxBytes(serverApp.GetN(), 0);
    if (udp)
    {
        for (uint32_t i = 0; i < serverApp.GetN(); i++)
        {
            rxBytes[i] = payloadSize * DynamicCast<UdpServer>(serverApp.Get(i))->GetReceived();
        }
    }
    else
    {
        for (uint32_t i = 0; i < serverApp.GetN(); i++)
        {
            rxBytes[i] = DynamicCast<PacketSink>(serverApp.Get(i))->GetTotalRx();
        }
    }
    return rxBytes;
};

int
main(int argc, char* argv[])
{
    writeLog("Hello Log");

    double simulationTime = 3; // seconds
    double distance = 1.0;     // meters

    int mcs = 13;
    int gi = 3200;              // ns
    uint32_t payloadSize = 700; // bytes
    uint64_t udpMaxPacketNumber = 4294967295U;

    // struct channelInfo
    // {
    //     int channelNumber;
    //     int channelWidth;
    //     double channelFrequency;
    //     int primaryIndex;
    // };
    // std::vector<channelInfo> channels = {{36, 20, 5, 0}, {40, 20, 5, 0}}; // channel info vector
    // std::vector<std::string> channelStr;
    // for (auto channel : channels)
    // {
    //     int channelWidth = channel.channelWidth;
    // }

    uint8_t nLinks = 3;
    std::string channelLink0 =
        "{36, 20, BAND_5GHZ, 0}"; // {channel number, channel width (MHz), PHY band, primary20
                                  // index} set to wifi-phy.cc and from
                                  // wifi-phy-operating-channel.cc
    std::string channelLink1 =
        "{1, 20, BAND_2_4GHZ, 0}"; // {channel number, channel width (MHz), PHY band, primary20
                                   // index} set to wifi-phy.cc and from
                                   // wifi-phy-operating-channel.cc

    std::string channelLink2 = "{13, 20, BAND_6GHZ, 0}";

    NodeContainer staNodes;
    staNodes.Create(1);
    NodeContainer apNode;
    apNode.Create(1);

    NetDeviceContainer staDevices;
    NetDeviceContainer apDevice;
    WifiMacHelper wifiMac;
    WifiHelper wifiHelper;

    wifiHelper.SetStandard(WIFI_STANDARD_80211be);
    std::vector<FrequencyRange> freqRanges;

    // Link 0
    freqRanges.push_back(WIFI_SPECTRUM_5_GHZ);
    uint64_t nonHtRefRateMbps = EhtPhy::GetNonHtReferenceRate(mcs) / 1e6;
    std::string ctrlRateStr = "OfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
    wifiHelper.SetRemoteStationManager((uint8_t)0,
                                       "ns3::ConstantRateWifiManager",
                                       "DataMode",
                                       StringValue("EhtMcs" + std::to_string(mcs)),
                                       "ControlMode",
                                       StringValue(ctrlRateStr));

    // Link 1
    freqRanges.push_back(WIFI_SPECTRUM_2_4_GHZ);
    Config::SetDefault("ns3::LogDistancePropagationLossModel::ReferenceLoss", DoubleValue(40));
    ctrlRateStr = "ErpOfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
    wifiHelper.SetRemoteStationManager((uint8_t)1,
                                       "ns3::ConstantRateWifiManager",
                                       "DataMode",
                                       StringValue("EhtMcs" + std::to_string(mcs)),
                                       "ControlMode",
                                       StringValue(ctrlRateStr));

    // Link 2
    freqRanges.push_back(WIFI_SPECTRUM_6_GHZ);
    Config::SetDefault("ns3::LogDistancePropagationLossModel::ReferenceLoss", DoubleValue(48));
    wifiHelper.SetRemoteStationManager((uint8_t)2,
                                       "ns3::ConstantRateWifiManager",
                                       "DataMode",
                                       StringValue("EhtMcs" + std::to_string(mcs)),
                                       "ControlMode",
                                       StringValue("EhtMcs" + std::to_string(mcs)));

    Ssid ssid = Ssid("ns3-80211be");

    // Link0 : Channel
    Ptr<MultiModelSpectrumChannel> spectrumChannelLink0 = CreateObject<MultiModelSpectrumChannel>();
    Ptr<LogDistancePropagationLossModel> lossModel =
        CreateObject<LogDistancePropagationLossModel>();
    spectrumChannelLink0->AddPropagationLossModel(lossModel);

    // Link1 : Channel
    Ptr<MultiModelSpectrumChannel> spectrumChannelLink1 = CreateObject<MultiModelSpectrumChannel>();
    spectrumChannelLink1->AddPropagationLossModel(lossModel);

    // Link2 : Channel
    Ptr<MultiModelSpectrumChannel> spectrumChannelLink2 = CreateObject<MultiModelSpectrumChannel>();
    spectrumChannelLink2->AddPropagationLossModel(lossModel);

    SpectrumWifiPhyHelper phy(nLinks);
    phy.SetPcapDataLinkType(WifiPhyHelper::DLT_IEEE802_11_RADIO);
    phy.Set("ChannelSwitchDelay", TimeValue(MicroSeconds(100)));
    phy.AddChannel(spectrumChannelLink0, freqRanges[0]);
    phy.AddChannel(spectrumChannelLink1, freqRanges[1]);
    phy.AddChannel(spectrumChannelLink2, freqRanges[2]);

    wifiMac.SetType("ns3::StaWifiMac",
                    "Ssid",
                    SsidValue(ssid),
                    "ActiveProbing",
                    BooleanValue(true));
    phy.Set(0, "ChannelSettings", StringValue(channelLink0));
    phy.Set(1, "ChannelSettings", StringValue(channelLink1));
    phy.Set(2, "ChannelSettings", StringValue(channelLink2));

    // emlsr ，启动 !
    if (true)
    {
        wifiHelper.ConfigEhtOptions("EmlsrActivated", BooleanValue(true));
        std::string emlsrLinks = "0,1,2";
        uint16_t paddingDelayUsec =
            128; // Possible values are 0 us, 32 us, 64 us, 128 us or 256 us.
        uint16_t transitionDelayUsec =
            64; // "Possible values are 0 us, 16 us, 32 us, 64 us, 128 us or 256 us.",
        bool switchAuxPhy =
            true; // 这个值在src/wifi/model/eht/default-emlsr-manager.cc中默认为true，但我并不理解他的作用
        wifiMac.SetEmlsrManager("ns3::DefaultEmlsrManager",
                                "EmlsrLinkSet",
                                StringValue(emlsrLinks),
                                "EmlsrPaddingDelay",
                                TimeValue(MicroSeconds(paddingDelayUsec)),
                                "EmlsrTransitionDelay",
                                TimeValue(MicroSeconds(transitionDelayUsec)),
                                "SwitchAuxPhy",
                                BooleanValue(switchAuxPhy));
    }
    staDevices = wifiHelper.Install(phy, wifiMac, staNodes);

    if (true)
    {
        NodeContainer tempNode;
        tempNode.Create(2);

        // Link 0
        SpectrumWifiPhyHelper tempphyLink0(1);
        tempphyLink0.SetPcapDataLinkType(WifiPhyHelper::DLT_IEEE802_11_RADIO);
        tempphyLink0.AddChannel(spectrumChannelLink0, freqRanges[0]);
        WifiHelper tempwifiHelperLink0;
        tempwifiHelperLink0.SetStandard(WIFI_STANDARD_80211be);
        int linkid = 0;
        uint64_t nonHtRefRateMbps = EhtPhy::GetNonHtReferenceRate(mcs) / 1e6;
        std::string ctrlRateStr = "OfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
        tempwifiHelperLink0.SetRemoteStationManager(linkid,
                                                    "ns3::ConstantRateWifiManager",
                                                    "DataMode",
                                                    StringValue("EhtMcs" + std::to_string(mcs)),
                                                    "ControlMode",
                                                    StringValue(ctrlRateStr));
        tempphyLink0.Set(0, "ChannelSettings", StringValue(channelLink0));
        staDevices.Add(tempwifiHelperLink0.Install(tempphyLink0, wifiMac, tempNode.Get(0)));
        staNodes.Add(tempNode.Get(0));

        // Link 1
        SpectrumWifiPhyHelper tempphyLink1(1);
        tempphyLink1.SetPcapDataLinkType(WifiPhyHelper::DLT_IEEE802_11_RADIO);
        tempphyLink1.AddChannel(spectrumChannelLink1, freqRanges[1]);
        WifiHelper tempwifiHelperLink1;
        tempwifiHelperLink1.SetStandard(WIFI_STANDARD_80211be);
        ctrlRateStr = "ErpOfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
        tempwifiHelperLink1.SetRemoteStationManager((uint8_t)0,
                                                    "ns3::ConstantRateWifiManager",
                                                    "DataMode",
                                                    StringValue("EhtMcs" + std::to_string(mcs)),
                                                    "ControlMode",
                                                    StringValue(ctrlRateStr));
        tempphyLink1.Set(0, "ChannelSettings", StringValue(channelLink1));
        staDevices.Add(tempwifiHelperLink1.Install(tempphyLink1, wifiMac, tempNode.Get(1)));

        staNodes.Add(tempNode.Get(1));
    }

    // ap

    wifiMac.SetType("ns3::ApWifiMac",
                    "Ssid",
                    SsidValue(ssid),
                    "EnableBeaconJitter",
                    BooleanValue(false));
    apDevice = wifiHelper.Install(phy, wifiMac, apNode);
    // Config::ConnectWithoutContextFailSafe(
    //     "/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/Phys/*/State/State",
    //     MakeCallback(&WifiStateTrace));
    Ptr<WifiNetDevice> wifi_dev = DynamicCast<WifiNetDevice>(apDevice.Get(0));
    ConnectContextTrace2Mac(wifi_dev, "wifi-ap");

    RngSeedManager::SetSeed(1);
    RngSeedManager::SetRun(1);
    int64_t streamNumber = 100;
    streamNumber += wifiHelper.AssignStreams(apDevice, streamNumber);
    streamNumber += wifiHelper.AssignStreams(
        staDevices,
        streamNumber); // 既为不同种类的device分配了不同的随机数种子，又保证了在不同的仿真实验中，随机数种子相同，从而保证了实验的可重复性。

    // Set guard interval and MPDU buffer size
    Config::Set("/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/HeConfiguration/GuardInterval",
                TimeValue(NanoSeconds(gi))); // 保护间隔，即每个OFDM符号之间的间隔

    // mobility.
    // 设置移动模型，即设置移动节点的位置，此处仅是静态模型
    MobilityHelper mobility;
    Ptr<ListPositionAllocator> positionAlloc = CreateObject<ListPositionAllocator>();

    positionAlloc->Add(Vector(0.0, 0.0, 0.0));
    positionAlloc->Add(Vector(distance, 0.0, 0.0));
    mobility.SetPositionAllocator(positionAlloc);

    mobility.SetMobilityModel("ns3::ConstantPositionMobilityModel");

    mobility.Install(apNode);
    mobility.Install(staNodes);

    /* Internet stack*/
    InternetStackHelper stack;
    stack.Install(apNode);
    stack.Install(staNodes);

    Ipv4AddressHelper address;
    address.SetBase("192.168.1.0", "255.255.255.0");
    Ipv4InterfaceContainer staNodeInterfaces;
    Ipv4InterfaceContainer apNodeInterface;

    staNodeInterfaces = address.Assign(staDevices);
    apNodeInterface = address.Assign(apDevice);

    /* Setting applications */
    ApplicationContainer serverApp;
    auto serverNodes = apNode;
    Ipv4InterfaceContainer serverInterfaces;
    NodeContainer clientNodes;

    serverInterfaces.Add(apNodeInterface.Get(0));
    clientNodes.Add(staNodes);

    for (uint16_t port = 9; port < 19; port++)
    {
        // uint16_t port = 9;
        UdpServerHelper server(port);
        serverApp = server.Install(serverNodes.Get(
            0)); // 此处的serverApp仅存储了一个端口的信息，若想记录所有端口，将该行修改为serverApp.Add(server.Install(serverNodes.Get(0)));
        serverApp.Start(Seconds(0.0));
        serverApp.Stop(Seconds(simulationTime + 1));

        UdpClientHelper client(serverInterfaces.GetAddress(0), port);
        client.SetAttribute("MaxPackets", UintegerValue(udpMaxPacketNumber));
        client.SetAttribute("Interval", TimeValue(Time("0.00001"))); // packets/s
        client.SetAttribute("PacketSize", UintegerValue(payloadSize));
        ApplicationContainer clientApp = client.Install(clientNodes);
        clientApp.Start(Seconds(1.0));
        clientApp.Stop(Seconds(simulationTime + 1));
    }

    // 查看每个staDevice的所有mac地址
    std::cout << "staDevices info :" << std::endl;
    for (u_int16_t i = 0; i < staDevices.GetN(); i++)
    {
        const Ptr<WifiNetDevice> device = DynamicCast<WifiNetDevice>(staDevices.Get(i));
        Ptr<WifiMac> mac = device->GetMac();
        for (uint8_t linkId = 0; linkId < std::max<uint8_t>(device->GetNPhys(), 1); ++linkId)
        {
            auto fem = mac->GetFrameExchangeManager(linkId);
            std::cout << "staDevice " << i << " linkId " << std::to_string(linkId)
                      << " mac address: " << fem->GetAddress() << std::endl;
        }
    }
    // 查看每个staDevice的所有mac地址
    std::cout << "apDevices info :" << std::endl;
    for (u_int16_t i = 0; i < apDevice.GetN(); i++)
    {
        const Ptr<WifiNetDevice> device = DynamicCast<WifiNetDevice>(apDevice.Get(i));
        Ptr<WifiMac> mac = device->GetMac();
        for (uint8_t linkId = 0; linkId < std::max<uint8_t>(device->GetNPhys(), 1); ++linkId)
        {
            auto fem = mac->GetFrameExchangeManager(linkId);
            std::cout << "apDevice " << i << " linkId " << std::to_string(linkId)
                      << " mac address: " << fem->GetAddress() << std::endl;
        }
    }

    // ap Info
    std::cout << "ap ip : " << apNodeInterface.GetAddress(0) << std::endl;
    std::cout << "ap mac : " << apDevice.Get(0)->GetAddress() << std::endl;

    // sta Info
    for (size_t i = 0; i < staDevices.GetN(); i++)
    {
        std::cout << "sta " << i << " ip : " << staNodeInterfaces.GetAddress(i) << std::endl;
        std::cout << "sta " << i << " mac : " << staDevices.Get(i)->GetAddress() << std::endl;
    }

    std::string pcapFileName = "multiStaTraceChannel-ap";
    phy.EnablePcap(pcapFileName, apDevice, true);
    std::string pcapFileName1 = "multiStaTraceChannel-sta";
    phy.EnablePcap(pcapFileName1, staDevices, true);

    std::cout << "MCS value"
              << "\t\t"
              << "Channel width"
              << "\t\t"
              << "GI"
              << "\t\t\t"
              << "Throughput" << '\n';

    Simulator::Stop(Seconds(simulationTime + 1));
    Simulator::Run();

    // cumulative number of bytes received by each server application
    std::vector<uint64_t> cumulRxBytes(1, 0);
    cumulRxBytes = StaApLinkLinkGetRxBytes(true, serverApp, payloadSize);
    uint64_t rxBytes =
        std::accumulate(cumulRxBytes.cbegin(), cumulRxBytes.cend(), 0); // 向量中元素累加
    double throughput = (rxBytes * 8) / (simulationTime * 1000000.0);   // Mbit/s

    Simulator::Destroy();
    std::cout << mcs << "\t\t\t"
              << "20 MHz\t\t\t" << gi << " ns\t\t\t" << throughput << " Mbit/s" << std::endl;

    return 0;
}